from sklearn import model_selection
import numpy as np
from common.utils.mlmodel_util import get_model_info_sklearn, preprocess, eval_metric, get_model_info_auto_sklearn
from common.utils.model_util import load_model
from common.utils.format_util import dup_name_handler
from config.auto_sklearn_config import AUTO_SKLEARN_CONFIG
from common.log import log_handler

# Module-level logger obtained from the project's LogHandler.
log = log_handler.LogHandler().get_log()

# auto-sklearn is imported as an optional dependency: a failed import is logged
# rather than aborting the import of this module (predict() below does not
# reference AutoSklearnClassifier directly).
# NOTE(review): on ImportError the name AutoSklearnClassifier is left undefined,
# so train() will fail with NameError when it tries to build the model.
try:
    from autosklearn.classification import AutoSklearnClassifier
except ImportError:
    log.error("haven't install auto-sklearn, so you cannot use auto-sklearn-classifier")


def predict(dataframe, para_list, output):
    """Classify every row of ``dataframe`` with a previously trained model.

    Args:
        dataframe: Table to classify; a prediction column is added in place.
        para_list: Task parameters. Reads ``model_id`` (model lookup),
            ``feature_col`` and ``label_col`` (names of existing columns),
            plus whatever ``preprocess`` reads.
        output: Result dict; the chosen prediction column name is stored in
            ``output["result"]["output_params"]["output_cols"]``.

    Returns:
        The same ``dataframe`` object, now carrying the predicted labels.
    """
    # --------------------load the model-----------------
    saved_path, label_list, feature_list = get_model_info_sklearn(para_list["model_id"])
    classifier = load_model(saved_path)

    # --------------------make predictions----------------
    # The model's outputs are used as indices into label_list, turning them
    # back into the original label values.
    features = preprocess(dataframe, para_list, feature_list)
    decoded = [label_list[class_index] for class_index in classifier.predict(features)]

    # Column name is passed through dup_name_handler together with the
    # existing feature/label names (presumably to avoid a clash).
    taken_names = para_list["feature_col"] + [para_list["label_col"]]
    column = dup_name_handler("_classification_", taken_names)
    output["result"]["output_params"]["output_cols"] = column
    dataframe[column] = decoded
    return dataframe


def train(dataframe, para_list, record, output):
    """Fit an auto-sklearn classifier on ``dataframe`` and record its metrics.

    Rows with a missing label are dropped, labels are encoded as indices into
    the sorted list of distinct labels, and the data is split into train/test
    parts (test fraction ``para_list["split_rate"][1]``, ``random_state=0``).

    Args:
        dataframe: Input table containing ``para_list["label_col"]``. The
            caller's object is not modified; a filtered copy is returned.
        para_list: Task parameters. Keys read here: ``label_col``,
            ``split_rate``, ``timeout``, ``include`` and ``feature_col`` (plus
            whatever ``preprocess`` reads).
        record: Dict that receives ``record["other_info"]`` (features, model
            info, label list and evaluation metrics).
        output: Dict that receives ``output["error"]`` when
            ``get_model_info_auto_sklearn`` returns an error string.

    Returns:
        ``(model, dataframe)``: the fitted model and the label-filtered
        dataframe. On success the dataframe has a ``"_classification_"``
        column holding the predicted label of every row.
    """
    feature_list = []  # passed to preprocess(), which presumably fills it in
    label_col = para_list["label_col"]

    # ---------------------drop unlabeled rows-------------
    raw_labels = dataframe[label_col].values
    if len(raw_labels) and isinstance(raw_labels[0], str):
        # Text labels: the literal string 'nan' marks a missing label.
        has_label = raw_labels != 'nan'
    else:
        has_label = ~np.isnan(raw_labels.astype(float))
    # Explicit copy so the result column added below is written to an
    # independent frame (avoids pandas' SettingWithCopyWarning).
    dataframe = dataframe[has_label].copy()

    X = preprocess(dataframe, para_list, feature_list)
    labels = dataframe[label_col].values
    # np.unique already returns its values sorted.
    label_list = list(np.unique(np.squeeze(labels)))
    # Dict lookup instead of list.index(): O(n) rather than O(n * n_labels).
    label_index = {label: index for index, label in enumerate(label_list)}
    y = [label_index[label] for label in labels]
    X_train, X_test, y_train, y_test = model_selection.train_test_split(
        X, y, test_size=para_list["split_rate"][1], random_state=0)

    # ---------------------model fit---------------------
    # Copy the shared default config so per-call settings never leak into the
    # module-level AUTO_SKLEARN_CONFIG (and from there into later calls).
    config = dict(AUTO_SKLEARN_CONFIG)
    config['time_left_for_this_task'] = para_list['timeout']
    config['include_estimators'] = para_list['include']
    model = AutoSklearnClassifier().set_params(**config)

    model.fit(X_train, y_train)
    model_info = get_model_info_auto_sklearn(model, "classifier")
    if isinstance(model_info, str):
        # A string result is reported as an error; evaluation is skipped.
        output["error"] = model_info
        return model, dataframe

    # ---------------------model validation---------------
    y_pred = model.predict(X_test)
    y_pred_res = model.predict(X)  # predictions for every row, reused below
    # The model predicts class *indices*, so coverage is checked against
    # indices, not raw label values. If some class never appears among the
    # test-set predictions, evaluate on all rows instead.
    predicted = set(np.asarray(y_pred).tolist())
    if any(index not in predicted for index in range(len(label_list))):
        y_pred_all, y_test_all = y_pred_res, y
    else:
        y_pred_all, y_test_all = y_pred, y_test

    # ---------------------get metrics--------------------
    # The confusion matrix is only requested for fewer than 1000 labels.
    compute_confusion_matrix = 1 if len(label_list) < 1000 else 0
    accuracy_test, precision_test, recall_test, F1_test, confusion_matrix = eval_metric(
        y_pred_all.astype(int), np.array(y_test_all), compute_confusion_matrix)

    # ---------------------record info----------------------
    other_info = {
        "feature_list": feature_list,
        "feature_col": para_list["feature_col"],
        "model_info": model_info,
        "label_list": label_list,
        "accuracy_test": accuracy_test,
        "Fmeasure_test": F1_test,
        "precision_test": precision_test,
        "recall_test": recall_test,
        # NOTE(review): guards the case where eval_metric returns no matrix
        # (e.g. when compute_confusion_matrix is 0) — confirm its contract.
        "confusion_matrix": None if confusion_matrix is None else confusion_matrix.tolist(),
    }
    record["other_info"] = other_info
    log.info(other_info)

    # Decode predicted indices back into the original label values.
    # NOTE(review): unlike predict(), this column name is not passed through
    # dup_name_handler, so an existing "_classification_" column is replaced.
    output_cols = "_classification_"
    dataframe[output_cols] = [label_list[p] for p in y_pred_res]
    return model, dataframe
